{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "638daffa",
   "metadata": {},
   "source": [
    "# Prototype Selection"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9c4b5f11",
   "metadata": {},
   "source": [
    "In this notebook, we show how to select prototypical examples from a source dataset that are representative of a target dataset. We experiment with the popular [digits dataset](https://scikit-learn.org/stable/auto_examples/datasets/plot_digits_last_image.html). The data is split into two partitions, **src** and **tgt**, which correspond to the source and target sets, respectively. [Our approach](https://link.springer.com/chapter/10.1007/978-3-030-86514-6_33) uses optimal transport theory to learn prototypes from **src** by matching the prototype distribution with the **tgt** distribution.\n",
    "\n",
    "This notebook can be found in our [**_examples folder_**](https://github.com/interpretml/interpret/tree/develop/docs/interpret/python/examples) on GitHub."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6025f59e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# install interpret if not already installed\n",
    "try:\n",
    "    import interpret\n",
    "except ModuleNotFoundError:\n",
    "    !pip install --quiet interpret numpy scikit-learn matplotlib"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7712fae6",
   "metadata": {},
   "source": [
    "We load the required packages. The prototype selection algorithm, SPOTgreedy, is provided by `SPOT_GreedySubsetSelection` in `interpret.utils`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34efebf4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.datasets import load_digits\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import pairwise_distances\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from interpret.utils import SPOT_GreedySubsetSelection  # SPOT prototype selection algorithm"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ec400f24",
   "metadata": {},
   "source": [
    "We now load the digits dataset and create the **src** and **tgt** sets by splitting the data 70/30, without shuffling."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bed6a55b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the digits dataset\n",
    "digits = load_digits()\n",
    "\n",
    "# Flatten the images\n",
    "n_samples = len(digits.images)\n",
    "data = digits.images.reshape((n_samples, -1))\n",
    "\n",
    "# Split data into 70% src and 30% tgt subsets \n",
    "X_src, X_tgt, y_src, y_tgt = train_test_split(\n",
    "    data, digits.target, test_size=0.3, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "31905373",
   "metadata": {},
   "source": [
    "Pairwise distances/dissimilarities between the source and target points are required. The optimal transport framework allows the use of any distance/dissimilarity measure. In this example, we use the Euclidean distance metric."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "051e20b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute the Euclidean distances between the X_src (source) and X_tgt (target) points.\n",
    "# C has one row per source point and one column per target point.\n",
    "C = pairwise_distances(X_src, X_tgt, metric='euclidean')"
   ]
  },
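  {
   "cell_type": "markdown",
   "id": "c7a91e04",
   "metadata": {},
   "source": [
    "Because only the cost matrix enters the prototype selection, a different dissimilarity can be tried by changing the `metric` argument of `pairwise_distances`. The snippet below is a minimal sketch on random toy data; the `toy_src`/`toy_tgt` arrays and the choice of the Manhattan metric are illustrative assumptions, not part of the experiment in this notebook.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from sklearn.metrics import pairwise_distances\n",
    "\n",
    "# Toy stand-ins for X_src and X_tgt (hypothetical data, 64 features like the digits)\n",
    "rng = np.random.default_rng(0)\n",
    "toy_src = rng.random((5, 64))\n",
    "toy_tgt = rng.random((3, 64))\n",
    "\n",
    "# Each cost matrix has one row per source point and one column per target point\n",
    "C_euclidean = pairwise_distances(toy_src, toy_tgt, metric='euclidean')\n",
    "C_manhattan = pairwise_distances(toy_src, toy_tgt, metric='manhattan')\n",
    "print(C_euclidean.shape, C_manhattan.shape)\n",
    "```\n",
    "\n",
    "Both matrices have the (source, target) shape that `C` has in this notebook."
   ]
  },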
  {
   "cell_type": "markdown",
   "id": "d6837b0a",
   "metadata": {},
   "source": [
    "`targetmarginal` is the empirical distribution over the target points. It is usually taken to be uniform, i.e., every target point is given equal importance. We consider two settings. In the first, `targetmarginal` is uniform. In the second, we skew `targetmarginal` toward points of a particular class. The experiments show that in both settings the learnt prototypes represent the target distribution `targetmarginal` well."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b2699d84",
   "metadata": {},
   "source": [
    "**Setting 1: target distribution is uniform**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43fa0e42",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the target marginal over the target set; here it is uniform\n",
    "targetmarginal = np.ones(C.shape[1]) / C.shape[1]\n",
    "\n",
    "# The number of prototypes to be computed\n",
    "numprototypes = 20\n",
    "\n",
    "# Run SPOTgreedy\n",
    "# prototypeIndices holds the indices (rows of C, i.e., points in X_src) of the chosen prototypes.\n",
    "# prototypeWeights holds the weight associated with each chosen prototype. The weights sum to 1.\n",
    "prototypeIndices, prototypeWeights = SPOT_GreedySubsetSelection(C, targetmarginal, numprototypes)\n",
    "\n",
    "# Plot the chosen prototypes, looked up in X_src because the indices refer to source points\n",
    "fig, axs = plt.subplots(nrows=5, ncols=4, figsize=(2, 2))\n",
    "for idx, ax in enumerate(axs.ravel()):\n",
    "    ax.imshow(X_src[prototypeIndices[idx]].reshape((8, 8)), cmap=plt.cm.binary)\n",
    "    ax.axis(\"off\")\n",
    "_ = fig.suptitle(\"Top prototypes selected from the 64-dimensional digit dataset with uniform target distribution\", fontsize=16)"
   ]
  },
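  {
   "cell_type": "markdown",
   "id": "e3b6f2d8",
   "metadata": {},
   "source": [
    "Besides plotting, one way to check which classes the prototypes cover is to tally the labels of the selected source points. A minimal sketch with made-up arrays (the `toy_labels` and `toy_prototype_indices` values are hypothetical; in this notebook they would correspond to `y_src` and `prototypeIndices`):\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical labels of six source points and four selected prototype indices\n",
    "toy_labels = np.array([0, 3, 3, 1, 3, 2])\n",
    "toy_prototype_indices = np.array([1, 2, 4, 0])\n",
    "\n",
    "# Count how many selected prototypes fall into each class\n",
    "labels, counts = np.unique(toy_labels[toy_prototype_indices], return_counts=True)\n",
    "for label, count in zip(labels, counts):\n",
    "    print(f'label {label}: {count} prototype(s)')\n",
    "```"
   ]
  },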
  {
   "cell_type": "markdown",
   "id": "457bca06",
   "metadata": {},
   "source": [
    "**Setting 2: target distribution is skewed**\n",
    "\n",
    "In this setting, we skew the target distribution toward the label `3`: each example in **tgt** with label `3` receives 90 times the weight of any other example. We expect a large majority of the learnt prototypes to also belong to the label `3`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e434c6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Skew the target marginal to give more weight to a specific class\n",
    "result = np.where(y_tgt == 3)  # find indices corresponding to label 3.\n",
    "\n",
    "targetmarginal_skewed = np.ones(C.shape[1])\n",
    "targetmarginal_skewed[result[0]] = 90  # Weigh the instances corresponding to label 3 more.\n",
    "targetmarginal_skewed = targetmarginal_skewed / np.sum(targetmarginal_skewed)\n",
    "\n",
    "# Run SPOTgreedy\n",
    "prototypeIndices_skewed, prototypeWeights_skewed = SPOT_GreedySubsetSelection(C, targetmarginal_skewed, numprototypes)\n",
    "\n",
    "# Plot the selected prototypes, looked up in X_src because the indices refer to source points\n",
    "fig, axs = plt.subplots(nrows=5, ncols=4, figsize=(2, 2))\n",
    "for idx, ax in enumerate(axs.ravel()):\n",
    "    ax.imshow(X_src[prototypeIndices_skewed[idx]].reshape((8, 8)), cmap=plt.cm.binary)\n",
    "    ax.axis(\"off\")\n",
    "_ = fig.suptitle(\"Top prototypes selected from the 64-dimensional digit dataset with skewed target distribution\", fontsize=16)"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
